import sys

sys.path.append("./")

import json
import os
import random
from multiprocessing import Pool

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms


class Clothing1mPP(Dataset):
    """Map-style dataset over the Clothing1M++ image folder.

    Labels come from a JSON file that holds a ``"labels"`` list and a
    ``"meta_data"`` mapping. Every sample is described by a "path info" row of
    five integers ``[class, color, material, pattern, file_number]`` from which
    the image path is rebuilt on demand. The train/val/test membership is read
    from ``*.pt`` id files stored in ``bin/cloth1m_data_v3`` next to this module.

    Args:
        root: dataset root; images are expected under ``<root>/images`` and the
            default label file under ``<root>/splits_labels/labels_meta.json``.
        image_size: side length of the square image produced by the default
            transform (also checked in ``__getitem__``).
        full_label: if True, ``__getitem__`` also returns the attribute labels.
        split: ``"train"``, ``"val"`` or ``"test"``.
        pre_load: an already built data package (the ``data_package`` attribute
            of another instance) to reuse instead of re-parsing the JSON file.
        transform: optional transform; when None a default one is built.
        clean_label: if True, the train split uses ``train_ids_clean_final.pt``.
        turk_label: if True, labels are read from the turk JSON file and no
            split filtering is applied.
    """

    def __init__(
        self,
        root,
        image_size,
        full_label=False,
        split="train",
        pre_load=None,
        transform=None,
        clean_label=False,
        turk_label=False,
    ):
        print(f"######### Start Loading Clothing1mPP dataset: {split} #########")
        self.root = root
        self.label_json_path = os.path.join(
            self.root, "splits_labels", "labels_meta.json"
        )
        # Directory shipped next to this module that holds the id files.
        self.split_path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), "bin", "cloth1m_data_v3"
        )
        self.turk_label_json_path = os.path.join(
            self.split_path, "turk_qualification_check_20k_launch_updated.json"
        )
        self.full_label = full_label
        self.turk_label = turk_label
        self.clean_label = clean_label
        self.image_size = image_size
        self.split = split
        self.transform = self.get_transform(transform)
        self.load_data_package(pre_load)
        self.split_train_val_test(clean_label)
        self.num_classes = len(np.unique(self.targets))
        self.clean_id = torch.load(
            os.path.join(self.split_path, "train_ids_clean_final.pt")
        )
        print(f"######### Done Loading Clothing1mPP dataset: {split} #########\n\n")

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.targets)

    def __getitem__(self, idx):
        """Load and transform one sample.

        Returns:
            ``(image, label, attribute, idx)`` when ``full_label`` is set,
            otherwise ``(image, label, idx)``.
        """
        label = self.targets[idx]
        attribute = self.attributes[idx]
        image = self.transform(self.load_image(self.data[idx]))
        expected_shape = (3, self.image_size, self.image_size)
        assert image.shape == expected_shape, f"image shape: {image.shape}"

        if self.full_label:
            return image, label, attribute, idx
        return image, label, idx

    def get_tiny_ids(self, seed, num_per_class=4167):
        """Draw a shuffled, class-balanced list of sample indices.

        Args:
            seed: seed for the global ``random`` module.
            num_per_class: number of indices drawn for every class
                (defaults to the previously hard-coded 4167).

        Returns:
            A list with ``num_per_class`` indices per class. Indices are drawn
            with ``random.choices``, i.e. with replacement, so they may repeat.
        """
        random.seed(seed)
        tiny_id = []
        for class_idx in range(self.num_classes):
            candidates = np.where(self.targets == class_idx)[0].tolist()
            tiny_id.extend(random.choices(candidates, k=num_per_class))
        random.shuffle(tiny_id)
        return tiny_id

    def _imbalance_profile(self, imbalance_factor):
        """Compute label counts, per-class factors and per-class sample counts.

        Class ``c`` keeps ``int(min_count * imbalance_factor ** (c / (C - 1)))``
        samples, where ``min_count`` is the size of the smallest class and
        ``C`` the number of bins returned by ``np.bincount``.
        """
        label_counts = np.bincount(self.targets)
        num_classes = len(label_counts)
        min_samples_per_class = np.min(label_counts)

        fac_list = []
        img_num_per_cls = []
        for class_idx in range(num_classes):
            # With a single class there is nothing to interpolate; avoid 0 / 0.
            exponent = class_idx / (num_classes - 1.0) if num_classes > 1 else 0.0
            factor = imbalance_factor**exponent
            fac_list.append(factor)
            img_num_per_cls.append(int(min_samples_per_class * factor))
        return label_counts, fac_list, img_num_per_cls

    def get_cls_num_list(self, seed, imbalance_factor):
        """Return (and print) the per-class sample counts of the imbalanced set."""
        # Nothing random is drawn here; the seed call is kept so the global
        # RNG state after this call is the same as before this refactor.
        random.seed(seed)
        _, _, img_num_per_cls = self._imbalance_profile(imbalance_factor)
        print("img_num_per_class", img_num_per_cls)
        return img_num_per_cls

    def get_imbalance_ids(self, seed, imbalance_factor, print_info=True):
        """Sample indices (without replacement) forming an imbalanced subset.

        Args:
            seed: seed for the global ``random`` module.
            imbalance_factor: ratio applied progressively from the first class
                (factor 1) to the last class (``imbalance_factor``).
            print_info: if True, print a summary of the sampling.

        Returns:
            List of selected indices, grouped by class in increasing class order.
        """
        targets = self.targets
        random.seed(seed)  # Set the random seed for reproducibility

        label_counts, fac_list, img_num_per_cls = self._imbalance_profile(
            imbalance_factor
        )

        imbalance_ids = []
        for class_idx, num_samples_class in enumerate(img_num_per_cls):
            label_indices = np.where(targets == class_idx)[0].tolist()
            imbalance_ids.extend(random.sample(label_indices, num_samples_class))

        # Print out key factors
        if print_info:
            print("Applying imbalance factor   :", imbalance_factor)
            print("Minimum samples per class:  :", np.min(label_counts))
            print("Imbalance factors per class :", [round(fac, 2) for fac in fac_list])
            print("Number of samples per class :", img_num_per_cls)
            print("Original label counts       :", label_counts.tolist())
            print(
                "Total imbalance IDs         :",
                len(imbalance_ids),
                f"({len(imbalance_ids)/len(targets)*100:.2f}%)",
            )

        return imbalance_ids

    # image_path_info: four label indices followed by the file number
    # example: [0, 0, 0, 0, 0] -> .../000000.jpg
    def load_image(self, image_path_info):
        """Open the image described by ``image_path_info`` and return it as RGB."""
        assert hasattr(self, "meta_data"), "meta_data is not loaded"
        image_path = self.unpack_image_path_info(image_path_info, self.meta_data)
        with Image.open(image_path) as img:
            return img.convert("RGB")

    # path_info: four label indices followed by the file number
    # example: [0, 0, 0, 0, 0] -> .../000000.jpg
    def unpack_image_path_info(self, path_info, meta_data):
        """Rebuild the image path ``<root>/images/<class>/<color>_<material>_<pattern>/<NNNNNN>.jpg``."""
        # Convert label indices to label names, then to directory names
        # (int -> str)
        names = self.get_label_from_index(path_info[:-1], meta_data)
        attri_dir_name = "_".join(names[1:])
        file_name = f"{path_info[-1]:06}.jpg"
        return os.path.join(
            self.root, "images", f"{names[0]}", attri_dir_name, file_name
        )

    def load_data_package(self, pre_load=None):
        """Populate the instance from ``pre_load`` or by parsing the JSON file.

        Every key of the package (``data``, ``targets``, ``attributes``,
        ``meta_data``, ``classes``) becomes an attribute of the same name.
        """
        if pre_load is not None:
            # 1. Reuse a package built by another instance
            print("Loading from pre-load")
            self.data_package = pre_load
        else:
            # 2. Load from scratch (JSON file)
            print("Loading from scratch")
            self.data_package = self.load_clothing1mpp()

        # Initialize data
        for key, value in self.data_package.items():
            setattr(self, key, value)

    def split_train_val_test(self, clean_label=False):
        """Restrict ``data``, ``targets`` and ``attributes`` to the current split.

        Nothing is filtered when ``turk_label`` is set. ``data_package`` itself
        is left untouched so it can be passed as ``pre_load`` to other splits.
        """
        print("Splitting data")
        # Mapping split names to file names
        split_files = {
            "train": "train_ids_clean_final.pt" if clean_label else "train_ids.pt",
            "val": "val_ids.pt",
            "test": "test_ids.pt",
        }

        if not self.turk_label:
            # Load ids based on the current split; converted to a NumPy array so
            # the same index works whether the file holds a tensor or a sequence.
            ids = np.asarray(
                torch.load(os.path.join(self.split_path, split_files[self.split]))
            )
            self.data = self.data[ids]
            self.targets = self.targets[ids]
            # Attributes must follow the same selection, otherwise
            # __getitem__ pairs an image with another sample's attributes.
            self.attributes = self.attributes[ids]
        print(f"Done loading data, split: {self.split}. total images: {len(self.data)}")

    ################################ Load data set from scratch ################################
    # @utils.timeit
    def load_clothing1mpp(self):
        """Parse the label JSON file into the data package dictionary.

        Returns:
            dict with ``data`` (array of path-info rows), ``targets`` (first
            label column), ``attributes`` (remaining label columns),
            ``meta_data`` and ``classes`` (keys of ``meta_data``).
        """
        json_path = (
            self.turk_label_json_path if self.turk_label else self.label_json_path
        )
        with open(json_path, "r") as f:
            data_dict = json.load(f)
        meta_data = data_dict["meta_data"]

        # Parallel data loading
        with Pool(processes=2) as pool:
            results = pool.starmap(
                self.extract_data,
                [(item, meta_data) for item in data_dict["labels"]],
            )

        # Organize results
        path_list, labels_list = zip(*results)
        labels = np.array(labels_list, dtype=np.int32)

        return {
            "data": np.array(path_list),
            "targets": labels[:, 0],
            "attributes": labels[:, 1:],
            "meta_data": meta_data,
            "classes": list(meta_data.keys()),
        }

    # Convert from string labels to integer labels
    def get_label(self, item, label_type, meta_data):
        """Retrieve the label index based on the label type.

        ``"Class"`` is looked up among the keys of ``meta_data``; any other
        ``label_type`` is looked up in that class's list of attribute values.
        """
        if label_type == "Class":
            return list(meta_data.keys()).index(item["Labels"])
        attribute_list = meta_data[item["Labels"]][label_type]
        return attribute_list.index(item["attributes"][label_type])

    def get_label_from_index(self, index_list, meta_data):
        """Map ``[class, color, material, pattern]`` indices back to their names."""
        assert len(index_list) == 4, f"index_list: {index_list}"
        # Class like 'Jacket'
        class_name = list(meta_data.keys())[index_list[0]]
        # Color, Material, Pattern
        attribute_list = meta_data[class_name]
        out_list = [class_name]
        for position, attribute in enumerate(["Color", "Material", "Pattern"], 1):
            out_list.append(attribute_list[attribute][index_list[position]])
        return out_list

    # Extract data from json file
    def extract_data(self, item, meta_data):
        """Load data and labels for a single item.

        Returns:
            ``(path_info, labels)`` where ``labels`` are the four integer labels
            and ``path_info`` is ``labels`` followed by the numeric file name.
        """
        # Convert labels from string to integer
        labels = [
            self.get_label(item, key, meta_data)
            for key in ["Class", "Color", "Material", "Pattern"]
        ]

        # Convert file name to integer. e.g. '000000.jpg' -> 0
        name_parts = item["file_name"].split(".")
        file_name_num = int(name_parts[0])
        assert name_parts[1] == "jpg", f"item['file_name']: {item['file_name']}"

        path_info = labels + [file_name_num]

        # Sanity check: the rebuilt path must match the recorded one and exist.
        file_path_unpack = self.unpack_image_path_info(path_info, meta_data)
        file_path_origin = os.path.join(self.root, "images", item["file_path"][1:])
        assert (
            file_path_origin == file_path_unpack
        ), f"file_path_origin: {file_path_origin}, file_path_unpack: {file_path_unpack}"
        assert os.path.exists(file_path_origin), f"file_path_origin: {file_path_origin}"

        return path_info, labels

    ############################### End Load data set from scratch ##############################

    def get_transform(self, transform):
        """Return ``transform`` unchanged, or build the default pipeline.

        The default resizes to ``image_size`` x ``image_size``, converts to a
        tensor and normalizes; the train split additionally applies a padded
        random crop and a random horizontal flip.
        """
        if transform is not None:
            return transform

        resize = transforms.Resize((self.image_size, self.image_size))
        normalize = transforms.Normalize(
            (0.6959, 0.6537, 0.6371), (0.3113, 0.3192, 0.3214)
        )
        if self.split == "train":
            steps = [
                resize,
                transforms.RandomCrop(self.image_size, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        else:
            steps = [resize, transforms.ToTensor(), normalize]
        return transforms.Compose(steps)


if __name__ == "__main__":
    # Smoke test: build every dataset variant and iterate a few batches of each.
    import argparse

    def _preview(loader, tag, last_batch=None):
        """Print the tensor shapes of the batches produced by ``loader``.

        Batches follow ``Clothing1mPP.__getitem__``: ``(images, labels, idx)``
        or, for full-label datasets, ``(images, labels, attributes, idx)``.
        The previous loops unpacked one element too few and therefore failed
        on the first batch.

        Args:
            loader: the DataLoader to iterate.
            tag: prefix identifying the loader in the printed line.
            last_batch: index of the last batch to print; None iterates the
                whole loader.
        """
        for i, batch in enumerate(loader):
            images, labels, *extras = batch
            fields = [
                f"{tag}  batch: ",
                i,
                "Total Batch Count: ",
                len(loader),
                "images.shape: ",
                images.shape,
                "labels.shape: ",
                labels.shape,
            ]
            # Two trailing elements means (attributes, idx); one means (idx,).
            if len(extras) == 2:
                fields += ["attributes.shape", extras[0].shape]
            print(*fields)
            if last_batch is not None and i >= last_batch:
                break

    # Create the parser
    parser = argparse.ArgumentParser(description="Script Configuration")

    # Add arguments with existing values as defaults
    parser.add_argument(
        "--root",
        type=str,
        default="/root/cloth1m_data_v3",
        help="root directory. xxx/cloth1m_data_v3 e.g. /home/minghao/Documents/cloth1m_data_v3",
    )
    parser.add_argument(
        "--batch_size", type=int, default=32, help="batch size for processing"
    )
    parser.add_argument("--image_size", type=int, default=64, help="size of the image")
    parser.add_argument("--num_workers", type=int, default=1, help="number of workers")

    # Parse the arguments
    args = parser.parse_args()
    root, batch_size, image_size, num_workers = (
        args.root,
        args.batch_size,
        args.image_size,
        args.num_workers,
    )

    # Create tiny training dataset
    train_set = Clothing1mPP(root, image_size, split="train")
    # Built only to check that the turk-label variant loads; not iterated below.
    train_set_turk = Clothing1mPP(root, image_size, split="train", turk_label=True)
    tiny_set_ids = train_set.get_tiny_ids(seed=0)
    tiny_train_set = Subset(train_set, tiny_set_ids)

    # Create imbalance training dataset
    train_set_clean = Clothing1mPP(root, image_size, split="train", clean_label=True)
    imbalance_ids = train_set_clean.get_imbalance_ids(seed=0, imbalance_factor=0.1)
    imbalance_train_set = Subset(train_set_clean, imbalance_ids)

    # Full-label variants reuse the already parsed data package.
    train_set_full = Clothing1mPP(
        root,
        image_size,
        split="train",
        full_label=True,
        pre_load=train_set.data_package,
    )
    tiny_set_full_ids = train_set_full.get_tiny_ids(seed=0)
    tiny_train_set_full = Subset(train_set_full, tiny_set_ids)

    assert np.array_equal(
        tiny_set_full_ids, tiny_set_ids
    ), "The selected ids should be the same due to the same seed"

    train_set_full_clean = Clothing1mPP(
        root,
        image_size,
        split="train",
        full_label=True,
        pre_load=train_set_clean.data_package,
        clean_label=True,
    )

    imb_set_full_ids = train_set_full_clean.get_imbalance_ids(
        seed=0, imbalance_factor=0.1
    )
    imb_train_set_full = Subset(train_set_full_clean, imbalance_ids)

    assert np.array_equal(
        imb_set_full_ids, imbalance_ids
    ), "The selected ids should be the same due to the same seed"

    val_set = Clothing1mPP(
        root, image_size, split="val", pre_load=train_set.data_package
    )
    test_set = Clothing1mPP(
        root, image_size, split="test", pre_load=train_set.data_package
    )

    train_loader = DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    tiny_train_loader = DataLoader(
        tiny_train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    imb_train_loader = DataLoader(
        imbalance_train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    train_full_loader = DataLoader(
        train_set_full, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    train_tiny_full_loader = DataLoader(
        tiny_train_set_full,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    val_loader = DataLoader(
        val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    test_loader = DataLoader(
        test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    print("Start running a sample train loop, break after a few batches")
    _preview(train_loader, "Train", last_batch=3)

    print("\nStart running a sample val loop, break after a few batches")
    _preview(val_loader, "Val", last_batch=3)

    print("\nStart running a sample test loop, break after a few batches")
    _preview(test_loader, "Test", last_batch=3)

    print("\nStart iterating the tiny training dataset")
    _preview(tiny_train_loader, "Tiny train")

    print("\nStart iterating the imbalance training dataset")
    _preview(imb_train_loader, "Imbalance train")

    print("\nStart iterating the training dataset which returns the full labels")
    _preview(train_full_loader, "Full-label train", last_batch=3)

    print("\nStart iterating the tiny training dataset which returns the full labels")
    _preview(train_tiny_full_loader, "Tiny full-label train")
